from abc import ABC, abstractmethod

import numba
import numpy as np
import heapq


class Learner(ABC):
    """Abstract interface shared by the learners in this module.

    Subclasses must implement ``act`` (choose an arm) and ``update``
    (take in the feedback observed for an arm).
    """

    @abstractmethod
    def act(self):
        """Choose the next arm; concrete learners must override this."""
        raise NotImplementedError

    @abstractmethod
    def update(self, arm, feedback):
        """Take in ``feedback`` for ``arm``; concrete learners must override this."""
        raise NotImplementedError

    def set_rng(self, rng):
        """Build a NumPy ``Generator`` from ``rng`` and keep it as ``self.rng``.

        ``rng`` is handed unchanged to ``np.random.default_rng``.
        """
        generator = np.random.default_rng(rng)
        self.rng = generator


class RandomPull(Learner):
    """Learner that picks an arm uniformly at random and ignores feedback."""

    def __init__(self, n_arms, rng=None):
        """Store the number of arms and optionally seed the learner.

        Args:
            n_arms: number of arms to choose from.
            rng: optional seed forwarded to ``set_rng``. When left as
                ``None`` no generator is created and ``act`` keeps drawing
                from NumPy's global random state.
        """
        self.n_arms = n_arms
        if rng is not None:
            self.set_rng(rng)

    def update(self, arm, feedback):
        """Do nothing: this policy does not use the feedback."""

    def act(self):
        """Return an arm index drawn uniformly from ``range(n_arms)``."""
        generator = getattr(self, 'rng', None)
        if generator is None:
            # Unseeded learner: keep using the global NumPy random state.
            return np.random.randint(self.n_arms)
        # A generator was installed through ``set_rng``; draw from it so that
        # seeding this learner actually makes its choices reproducible.
        return int(generator.integers(self.n_arms))


class GaussTS(Learner):
    """Thompson sampling with a normal prior and normal noise for each arm.

    Every arm is first played once in index order; afterwards an arm is
    chosen by sampling each arm's mean from its posterior and taking the
    largest sample.
    """

    def __init__(self, n_arms, noise_var=1, prior_mean=0, prior_var=1, rng=42):
        """Set up priors, statistics and the random generator.

        Args:
            n_arms: number of arms.
            noise_var: variance used for the reward noise.
            prior_mean: prior mean, replicated over the arms.
            prior_var: prior variance, replicated over the arms.
            rng: seed handed to ``set_rng``.
        """
        self.n_arms = n_arms
        self.noise_var = noise_var
        self.set_rng(rng)
        # Per-arm prior parameters.
        self.mu0 = prior_mean * np.ones(n_arms)
        self.var0 = prior_var * np.ones(n_arms)
        # Per-arm statistics: pull counts and cumulative rewards.
        self.pulls = np.zeros(n_arms)
        self.reward = np.zeros(n_arms)
        # Number of warm-up pulls issued so far, and the history of choices.
        self.counter = 0
        self.arm_his = []

    def update(self, arm, feedback):
        """Count one more pull of ``arm`` and add the summed ``feedback``."""
        self.reward[arm] += np.sum(feedback)
        self.pulls[arm] += 1

    def act(self):
        """Return the next arm and append it to ``arm_his``."""
        if self.counter < self.n_arms:
            # Warm-up: arms 0, 1, ..., n_arms - 1 are each played once.
            arm = self.counter
            self.counter = arm + 1
            self.arm_his.append(arm)
            return arm

        # Posterior variance and mean of every arm.
        precision = 1.0 / self.var0 + self.pulls / self.noise_var
        post_var = 1.0 / precision
        weighted = self.mu0 / self.var0 + self.reward / self.noise_var
        post_mean = post_var * weighted

        # One posterior sample per arm; play the arm with the largest one.
        draws = self.rng.normal(size=self.n_arms)
        self.mu = post_mean + np.sqrt(post_var) * draws
        arm = np.argmax(self.mu)
        self.arm_his.append(arm)
        return arm


class UCB(Learner):
    """Upper-confidence-bound learner.

    The index of an arm is its empirical mean plus a confidence radius.
    With ``logt=False`` the radius is ``sqrt(2 * radius / n_pulls)``; with
    ``logt=True`` it is additionally scaled by ``sqrt(log(counter))``.
    """

    def __init__(self, n_arms, radius=2, logt=False):
        """Initialize statistics; every index starts at ``+inf``.

        Args:
            n_arms: number of arms.
            radius: log(1/delta') in the paper; stored as ``sqrt(2*radius)``.
            logt: whether the radius grows with ``log(counter)``.
        """
        self.n_arms = n_arms
        self.logt = logt
        # radius is log(1/delta') in the paper
        self.radius = np.sqrt(2*radius)
        self.counter = 1
        self.n_pulls = np.zeros(n_arms)
        self.cumul_rewards = np.zeros(n_arms)
        # Infinite indices make argmax visit unpulled arms first.
        self.UCBs = np.full(n_arms, float('inf'))
        self.arm_his = []

    def update(self, arm, reward):
        """Record ``reward`` for ``arm`` and refresh the indices."""
        self.cumul_rewards[arm] += reward
        self.n_pulls[arm] += 1
        if self.logt:
            if np.min(self.n_pulls) == 0:
                # Some arm is still unpulled: park this arm's index at 0 so
                # that the remaining infinite indices win the argmax.
                self.UCBs[arm] = 0
            else:
                # The radius depends on the counter, so refresh all arms.
                radii = self._get_radii()
                means = self.cumul_rewards/self.n_pulls
                self.UCBs = means + radii
        else:
            mean_arm = self.cumul_rewards[arm]/self.n_pulls[arm]
            width = self._get_radius(arm)
            self.UCBs[arm] = mean_arm + width
        self.counter += 1

    def act(self):
        """Return the arm with the largest index and log it in ``arm_his``."""
        choice = np.argmax(self.UCBs)
        self.arm_his.append(choice)
        return choice

    def _get_radius(self, arm):
        """Confidence radius of a single arm."""
        pulls = self.n_pulls[arm]
        if self.logt:
            return self.radius*np.sqrt(np.log(self.counter)/pulls)
        return self.radius/np.sqrt(pulls)

    def _get_radii(self):
        """Confidence radii of all arms as an array."""
        if self.logt:
            return self.radius*np.sqrt(np.log(self.counter)/self.n_pulls)
        return self.radius/np.sqrt(self.n_pulls)


class SumUCB(UCB):
    """UCB on the sum of the feedback, with radii scaled by ``n_variables``."""

    def __init__(self, n_arms, n_variables, radius=2, logt=False):
        """Initialize the parent UCB state and remember ``n_variables``."""
        super().__init__(n_arms, radius, logt)
        self.n_variables = n_variables

    def update(self, arm, feedback):
        """Feed the summed ``feedback`` to the parent update as the reward."""
        total = np.sum(feedback)
        super().update(arm, total)

    def _get_radius(self, arm):
        """Parent radius of ``arm`` multiplied by ``n_variables``."""
        base = super()._get_radius(arm)
        return self.n_variables * base

    def _get_radii(self):
        """Parent radii of all arms multiplied by ``n_variables``."""
        base = super()._get_radii()
        return self.n_variables * base


class UpUCB(UCB):
    """UCB on uplifts: each arm's UCB minus a baseline over its affected set.

    Only the variables in ``affected_sets[arm]`` contribute to the reward of
    ``arm``. The index used for selection (``UPs``) is the arm's UCB minus
    the baseline summed over those variables. The baseline is either given
    (``baseline`` array) or estimated from the variables each pulled arm
    leaves unaffected.
    """

    def __init__(self, n_arms, n_variables, affected_sets,
                 baseline=None, baseline_option='UCB',
                 radius=2, logt=False):
        """Initialize the learner.

        Args:
            n_arms: number of arms.
            n_variables: number of feedback variables.
            affected_sets: per arm, the indices of the variables it affects.
            baseline: per-variable baseline values indexable with NumPy
                index arrays, or ``None`` to estimate them online.
            baseline_option: how the estimated baseline enters the index:
                ``'LCB'`` subtracts a confidence width, ``'est'`` uses the
                plain estimate, anything else adds the width.
            radius: forwarded to ``UCB.__init__``.
            logt: forwarded to ``UCB.__init__``.
        """
        super().__init__(n_arms, radius, logt)
        self.n_variables = n_variables
        self.affected_sets = affected_sets
        self.ns_affected = np.array([len(s) for s in affected_sets])
        self.UPs = np.repeat(float('inf'), n_arms)
        self.baseline = baseline
        self.baseline_option = baseline_option
        if baseline is None:
            # Running per-variable statistics used to estimate the baseline.
            self.baseline_cumuls = np.zeros(n_variables)
            self.baseline_pulls = np.zeros(n_variables)
            self.unaffected_sets = [
                list(set(range(n_variables)) - set(s)) for s in affected_sets]
        else:
            # Known baseline: its sum over each affected set is fixed.
            self.baseline_per_arm = np.zeros(n_arms)
            for k in range(n_arms):
                affected = np.ix_(self.affected_sets[k])
                self.baseline_per_arm[k] = np.sum(self.baseline[affected])

    def update(self, arm, feedback):
        """Record ``feedback`` for ``arm`` and refresh the uplift indices.

        Args:
            arm: index of the pulled arm.
            feedback: per-variable feedback array of the pull.
        """
        affected = np.ix_(self.affected_sets[arm])
        reward = np.sum(feedback[affected])
        # After this call ``self.counter`` has already been increased by 1.
        super().update(arm, reward)

        if self.baseline is not None:
            if self.logt:
                # With ``logt`` the parent recomputes the UCBs of *all* arms,
                # so every uplift index must follow; refreshing only the
                # pulled arm would leave the others stale.
                self.UPs[:] = self.UCBs - self.baseline_per_arm
            else:
                self.UPs[arm] = self.UCBs[arm] - self.baseline_per_arm[arm]
            return

        # Unknown baseline: the variables this arm leaves unaffected are
        # treated as fresh observations of the baseline.
        unaffected = np.ix_(self.unaffected_sets[arm])
        self.baseline_cumuls[unaffected] += feedback[unaffected]
        self.baseline_pulls[unaffected] += 1
        self.baseline_estimates = (
            self.baseline_cumuls / np.maximum(1, self.baseline_pulls))
        radius = (self.radius * np.sqrt(np.log(self.counter))
                  if self.logt else self.radius)
        if self.baseline_option == 'LCB':
            self.baseline_cmps = (
                self.baseline_estimates
                - radius / np.maximum(1, np.sqrt(self.baseline_pulls)))
        elif self.baseline_option == 'est':
            self.baseline_cmps = self.baseline_estimates
        else:
            self.baseline_cmps = (
                self.baseline_estimates
                + radius / np.maximum(1, np.sqrt(self.baseline_pulls)))
        # The baseline estimates moved, so every arm's index is refreshed.
        for k in range(self.n_arms):
            affected_k = np.ix_(self.affected_sets[k])
            self.UPs[k] = self.UCBs[k] - np.sum(self.baseline_cmps[affected_k])

    def act(self):
        """Return the arm with the largest uplift index and log it."""
        pulled = np.argmax(self.UPs)
        self.arm_his.append(pulled)
        return pulled

    def _get_radius(self, arm):
        """Parent radius of ``arm`` scaled by the size of its affected set."""
        return self.ns_affected[arm] * super()._get_radius(arm)

    def _get_radii(self):
        """Parent radii of all arms scaled by their affected-set sizes."""
        return self.ns_affected * super()._get_radii()


@numba.jit(nopython=True)
def compute_UPs(ucbs, baseline, affected_sets):
    """Return, for each arm, its UCB minus the baseline over its affected set.

    ``affected_sets[k]`` holds the indices into ``baseline`` that are summed
    and subtracted from ``ucbs[k]``.
    """
    count = len(ucbs)
    # Every entry is written below, so the buffer needs no zero fill.
    result = np.empty(count)
    for k in range(count):
        total = 0
        for var in affected_sets[k]:
            total += baseline[var]
        result[k] = ucbs[k] - total
    return result


class UpUCB_L(UCB):
    """Uplift UCB that only knows how many variables each arm affects.

    Per-variable statistics are kept for every arm. A variable counts as
    identified for an arm when its estimate differs from the baseline by
    more than the combined confidence radii. The baseline is either the
    given ``baseline`` or the estimate of the most-pulled arm.
    """

    def __init__(self, n_arms, n_variables, n_affected,
                 baseline=None, radius=2, logt=False):
        """Set all state directly; the parent initializer is not called.

        Args:
            n_arms: number of arms.
            n_variables: number of feedback variables.
            n_affected: number of variables counted per arm when a baseline
                is given (twice that number is used otherwise).
            baseline: per-variable baseline values, or ``None``.
            radius: stored as ``sqrt(2*radius)``.
            logt: whether radii grow with ``log(counter)``.
        """
        self.n_arms = n_arms
        self.n_variables = n_variables
        self.n_affected = n_affected
        self.baseline = baseline
        self.logt = logt
        self.radius = np.sqrt(2*radius)
        self.counter = 1

        self.n_pulls = np.zeros(n_arms)
        self.cumuls_per_variable = np.zeros((n_arms, n_variables))
        self.UPs = np.full(n_arms, float('inf'))
        self.identified_sets = [[] for _ in range(n_arms)]
        self.arm_his = []

    def update(self, arm, feedback):
        """Record per-variable ``feedback`` for ``arm`` and refresh ``UPs``."""
        self.cumuls_per_variable[arm] += feedback
        self.n_pulls[arm] += 1
        known_baseline = self.baseline is not None
        if known_baseline and not self.logt:
            # Only the pulled arm's statistics changed.
            width = self._get_radius(arm)
            self._update_UP(arm, self.baseline, width, 0, self.n_affected)
        elif np.min(self.n_pulls) == 0:
            # Initialization phase
            self.UPs[arm] = 0
        elif known_baseline:
            # Radii depend on the counter: refresh every arm.
            for k in range(self.n_arms):
                width = self._get_radius(k)
                self._update_UP(k, self.baseline, width, 0, self.n_affected)
        else:
            # No baseline given: use the most-pulled arm as the reference.
            ref = np.argmax(self.n_pulls)
            ref_means = self.cumuls_per_variable[ref]/self.n_pulls[ref]
            ref_width = self._get_radius(ref)
            self.UPs[ref] = 0
            for k in range(self.n_arms):
                if k == ref:
                    continue
                width = self._get_radius(k)
                self._update_UP(k, ref_means, width, ref_width, 2*self.n_affected)
            self.baseline_arm = ref
            self.identified_sets[ref] = []
        self.counter += 1

    def act(self):
        """Return the arm with the largest uplift index and log it."""
        choice = np.argmax(self.UPs)
        self.arm_his.append(choice)
        return choice

    def _update_UP(self, arm, baseline, radius, radius_baseline, n_diff):
        """Recompute ``UPs[arm]`` and ``identified_sets[arm]``.

        The index sums the optimistic gap (arm UCB minus baseline UCB) over
        the identified variables. If fewer than ``n_diff`` variables are
        identified, the largest non-negative gaps among the unidentified
        ones fill the remaining slots.
        """
        means = self.cumuls_per_variable[arm]/self.n_pulls[arm]
        identified = np.abs(means - baseline) > radius + radius_baseline
        self.identified_sets[arm] = identified.nonzero()[0]
        gaps = (means + radius) - (baseline + radius_baseline)
        self.UPs[arm] = np.sum(gaps[identified])
        n_identified = np.sum(identified)
        if n_identified < n_diff:
            # Find the unidentified elements with the largest uplifting indices
            candidates = np.maximum(0, gaps[~identified])
            n_missing = n_diff - n_identified
            self.UPs[arm] += np.sum(heapq.nlargest(n_missing, candidates))


class UpUCB_Gap(UpUCB):
    """UpUCB that learns each arm's affected set from a minimum-lift gap.

    Every arm starts with all variables marked as affected. Once an arm has
    at least ``n0`` pulls, its affected set becomes the variables whose
    estimate differs from the baseline by more than ``minimum_lift / 2``.
    """

    def __init__(self, n_arms, n_variables, minimum_lift, baseline,
                 radius=2, radius_n0=None):
        """Initialize the learner.

        Args:
            n_arms: number of arms.
            n_variables: number of feedback variables.
            minimum_lift: gap used for the identification threshold.
            baseline: per-variable baseline values.
            radius: forwarded to ``UpUCB.__init__``.
            radius_n0: radius used for ``n0``; defaults to ``radius``.
        """
        all_variables = [np.arange(n_variables) for _ in range(n_arms)]
        super().__init__(
            n_arms, n_variables, all_variables,
            baseline, radius=radius, logt=False)
        self.minimum_lift = minimum_lift
        self.cumuls_per_variable = np.zeros((n_arms, n_variables))
        n0_radius = radius if radius_n0 is None else radius_n0
        # Pull count from which the affected set of an arm is re-identified.
        self.n0 = 8*n0_radius/(minimum_lift**2)

    def update(self, arm, feedback):
        """Record ``feedback`` for ``arm`` and refresh its indices."""
        self.cumuls_per_variable[arm] += feedback
        self.n_pulls[arm] += 1
        pulls = self.n_pulls[arm]
        if pulls >= self.n0:
            means = self.cumuls_per_variable[arm]/pulls
            lifted = np.abs(means - self.baseline) > self.minimum_lift/2
            detected = np.sort(lifted.nonzero()[0])
            if not np.array_equal(self.affected_sets[arm], detected):
                # The affected set changed: update everything derived from it.
                self.affected_sets[arm] = detected
                self.ns_affected[arm] = len(detected)
                self.baseline_per_arm[arm] = np.sum(self.baseline[detected])
        # The reward only counts the variables currently marked as affected.
        selector = np.ix_(self.affected_sets[arm])
        self.cumul_rewards[arm] = np.sum(self.cumuls_per_variable[arm][selector])
        mean_reward = self.cumul_rewards[arm]/pulls
        width = self._get_radius(arm)
        self.UCBs[arm] = mean_reward + width
        self.UPs[arm] = self.UCBs[arm] - self.baseline_per_arm[arm]
        self.counter += 1
